/*    Copyright 2010 Tobias Marschall
 *
 *    This file is part of MoSDi.
 *
 *    MoSDi is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    MoSDi is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoSDi.  If not, see <http://www.gnu.org/licenses/>.
 */

package mosdi.distributions;

import java.util.Arrays;

import mosdi.util.Combinatorics;
import mosdi.util.Log;
import mosdi.util.LogSpace;

public class PoissonDistribution {
	/** The Poisson tail P(N>k) is summed until the running sum exceeds the latest term by this factor. */
	private static final double TAIL_STOP_RATIO = 1e14;
	/** Natural logarithm of TAIL_STOP_RATIO, used by the log-space variants. */
	private static final double LOG_TAIL_STOP_RATIO = Math.log(TAIL_STOP_RATIO);
	/** In linear space, a tail term below this value ends the tail summation. */
	private static final double TAIL_UNDERFLOW_LIMIT = 1e-300;

	/** Mean of the distribution. */
	private final double lambda;
	/** Cached Math.log(lambda); negative infinity when lambda is 0. */
	private final double logLambda;
	
	/**
	 * Creates a Poisson distribution with the given mean.
	 * @param lambda the mean; must be finite and non-negative
	 * @throws IllegalArgumentException if lambda is negative, NaN or infinite
	 */
	public PoissonDistribution(double lambda) {
		// A negative/NaN mean makes every probability NaN and an infinite mean keeps
		// "i<lambda" true forever; both would make the tail summation loops never end.
		if (!(lambda >= 0.0) || Double.isInfinite(lambda)) {
			throw new IllegalArgumentException("lambda must be finite and non-negative, got " + lambda);
		}
		this.lambda = lambda;
		this.logLambda = Math.log(lambda);
	}
	
	/**
	 * Returns the natural logarithm of P(X=k).
	 * @param k the value; negative values have probability 0
	 * @return log P(X=k), or negative infinity if k<0
	 */
	public double getLogProb(int k) {
		if (k < 0) return Double.NEGATIVE_INFINITY;
		// Handled separately: for lambda=0, logLambda is -inf and -inf*0 would be NaN,
		// although P(X=0)=exp(-lambda) holds for every lambda.
		if (k == 0) return -lambda;
		return -lambda + logLambda * k - Combinatorics.logFactorial(k);
	}
	
	/** Returns P(X=k). */
	public double getProb(int k) {
		return Math.exp(getLogProb(k));
	}
	
	/**
	 * Returns the distribution in the range 0..size-1.
	 * @param size number of entries; values <=0 yield an empty array
	 * @param accumulateTail If true, the last value, i.e. result[size-1],
	 *                       contains the summed up tail probability P(X>=size-1) such
	 *                       that the whole distribution sums up to 1. Otherwise it
	 *                       contains P(X=size-1).
	 */
	public double[] get(int size, boolean accumulateTail) {
		if (size <= 0) return new double[0];
		double[] result = new double[size];
		int last = size - 1;
		for (int i = 0; i < last; ++i) result[i] = getProb(i);
		result[last] = accumulateTail ? getTailProbability(last) : getProb(last);
		return result;
	}
	
	/**
	 * Returns the smallest n greater than floor(lambda) such that
	 * log P(X=n) <= log(Double.MIN_VALUE), i.e. beyond which all probabilities are at
	 * most the smallest positive double. Finds an upper bound by repeated doubling and
	 * then narrows it down with a binary search.
	 */
	public int pvaluePrecisionLimit() {
		double logMinValue = Math.log(Double.MIN_VALUE);
		// Start above the mode so that the log-probabilities decrease monotonically.
		int n = 1 + (int) lambda;
		while (getLogProb(n) > logMinValue) n *= 2;
		// search interval [n/2, n] using binary search
		int l = n / 2;
		int r = n;
		while (l + 1 < r) {
			int m = (l + r) >>> 1; // overflow-safe midpoint
			if (getLogProb(m) > logMinValue) l = m;
			else r = m;
		}
		return r;
	}
	
	/**
	 * Returns the largest k such that P(X>=k) > pvalue, or 0 if the accumulated
	 * probability never exceeds pvalue.
	 */
	public int getQuantileByPValue(double pvalue) {
		// p accumulates P(X>=n), starting in the negligible far tail
		double p = 0.0;
		for (int n = pvaluePrecisionLimit(); n >= 0; --n) {
			p += getProb(n);
			if (p > pvalue) return n;
		}
		return 0;
	}
	
	/**
	 * Sanity check that the probabilities P(X=0), ..., P(X=k) do not sum to more than 1.
	 * @throws IllegalStateException if the sum exceeds 1.0
	 */
	public void check(int k) {
		double sum = 0.0;
		for (int i = 0; i <= k; ++i) {
			sum += getProb(i);
		}
		if (sum > 1.0) throw new IllegalStateException(String.format("BUG: pSum=%e", sum));
	}
	
	/**
	 * Calculates P(X>=k) for the compound Poisson variable X = S_1+...+S_N, where N is
	 * Poisson distributed with mean lambda and the S_i are independent clump sizes with
	 * P(S_i=x)=componentDist[x]. Since componentDist[0] must be 0, every clump has size
	 * >=1, so more than k clumps always give X>=k (assuming componentDist sums to 1).
	 * @param componentDist clump-size distribution; componentDist[0] must be 0
	 * @param k threshold; k<=0 yields 1.0
	 * @throws IllegalArgumentException if componentDist is empty, componentDist[0]!=0.0,
	 *         or componentDist contains only zeros
	 */
	public double compoundPoissonPValue(double[] componentDist, int k) {
		if (k <= 0) return 1.0;
		double[] clumpDist = trimComponentDist(componentDist);
		// dist[x]=P(#matches=x|#clumps=i) for x<k, dist[k]=P(#matches>=k|#clumps=i); i=1 here
		double[] dist = initialDist(clumpDist, k);
		double p = getProb(1) * dist[k];
		for (int clumps = 2; clumps <= k; ++clumps) {
			convolveTruncated(dist, clumpDist);
			p += getProb(clumps) * dist[k];
		}
		// more than k clumps imply a value >=k
		return addClumpCountTail(p, k);
	}

	/**
	 * Calculates log P(X>=k) for the compound Poisson variable described in
	 * {@link #compoundPoissonPValue(double[], int)}, working in log space throughout.
	 * @param componentDist clump-size distribution; componentDist[0] must be 0
	 * @param k threshold; k<=0 yields 0.0 (= log 1)
	 * @throws IllegalArgumentException if componentDist is empty, componentDist[0]!=0.0,
	 *         or componentDist contains only zeros
	 */
	public double logCompoundPoissonPValue(double[] componentDist, int k) {
		if (k <= 0) return 0.0;
		double[] logClumpDist = logOf(trimComponentDist(componentDist));
		double[] logDist = initialLogDist(logClumpDist, k);
		double p = getLogProb(1) + logDist[k];
		for (int clumps = 2; clumps <= k; ++clumps) {
			logConvolveTruncated(logDist, logClumpDist);
			p = LogSpace.logAdd(p, getLogProb(clumps) + logDist[k]);
		}
		// more than k clumps imply a value >=k
		return logAddClumpCountTail(p, k);
	}

	/**
	 * Calculates the compound Poisson distribution given by componentDist and lambda
	 * explicitly, up to k, i.e. the returned distribution has length k+1. Entry x<k is
	 * P(X=x); the last entry is P(X>=k), therefore the table sums up to 1.
	 * The computation is timed via Log.startTimer()/Log.stopTimer(...).
	 * @param componentDist clump-size distribution; componentDist[0] must be 0
	 * @param k last table index; k<=0 yields {1.0}
	 * @throws IllegalArgumentException if componentDist is empty, componentDist[0]!=0.0,
	 *         or componentDist contains only zeros
	 */
	public double[] compoundPoissonDistribution(double[] componentDist, int k) {
		if (k <= 0) {
			return new double[] {1.0};
		}
		double[] clumpDist = trimComponentDist(componentDist);
		Log.startTimer();
		double[] dist = initialDist(clumpDist, k);
		double[] result = new double[k + 1];
		// zero clumps give value 0 (componentDist[0]=0, so no other way to reach 0)
		result[0] = getProb(0);
		addWeighted(result, getProb(1), dist);
		for (int clumps = 2; clumps <= k; ++clumps) {
			convolveTruncated(dist, clumpDist);
			addWeighted(result, getProb(clumps), dist);
		}
		// more than k clumps imply a value >=k
		result[k] = addClumpCountTail(result[k], k);
		Log.stopTimer("Calculate distribution (compound poisson)");
		return result;
	}

	/**
	 * Like {@link #compoundPoissonDistribution(double[], int)}, but every entry is the
	 * natural logarithm of the corresponding probability. The last entry is
	 * log P(X>=k), so the table sums up to 1 in linear space.
	 * @param componentDist clump-size distribution; componentDist[0] must be 0
	 * @param k last table index; k<=0 yields {0.0} (= {log 1})
	 * @throws IllegalArgumentException if componentDist is empty, componentDist[0]!=0.0,
	 *         or componentDist contains only zeros
	 */
	public double[] logCompoundPoissonDistribution(double[] componentDist, int k) {
		if (k <= 0) return new double[1];
		double[] clumpDist = trimComponentDist(componentDist);
		Log.startTimer();
		double[] logClumpDist = logOf(clumpDist);
		double[] logDist = initialLogDist(logClumpDist, k);
		double[] result = new double[k + 1];
		Arrays.fill(result, Double.NEGATIVE_INFINITY);
		result[0] = getLogProb(0);
		logAddWeighted(result, getLogProb(1), logDist);
		for (int clumps = 2; clumps <= k; ++clumps) {
			logConvolveTruncated(logDist, logClumpDist);
			logAddWeighted(result, getLogProb(clumps), logDist);
		}
		// more than k clumps imply a value >=k, so P(N>k) belongs to the last entry
		result[k] = logAddClumpCountTail(result[k], k);
		Log.stopTimer("Calculate log distribution (compound poisson)");
		return result;
	}

	/**
	 * Validates a clump-size distribution and strips trailing zero entries.
	 * @return componentDist itself if it has no trailing zeros, otherwise a shortened copy
	 * @throws IllegalArgumentException if componentDist is empty, componentDist[0]!=0.0,
	 *         or all entries are zero
	 */
	private static double[] trimComponentDist(double[] componentDist) {
		if (componentDist.length == 0) throw new IllegalArgumentException("Component distribution is empty.");
		if (componentDist[0] != 0.0) throw new IllegalArgumentException("Not implemented for componentDist[0]!=0.0");
		int last = componentDist.length - 1;
		while (last > 0 && componentDist[last] == 0.0) --last;
		// componentDist[0] is known to be zero, so last==0 means every entry is zero
		if (last == 0) throw new IllegalArgumentException("Component distributions contains only zeros.");
		return (last < componentDist.length - 1) ? Arrays.copyOf(componentDist, last + 1) : componentDist;
	}

	/** Returns the element-wise natural logarithm of values (log 0 = negative infinity). */
	private static double[] logOf(double[] values) {
		double[] logs = new double[values.length];
		for (int i = 0; i < values.length; ++i) logs[i] = Math.log(values[i]);
		return logs;
	}

	/**
	 * Builds the truncated table for a single clump: entry x<k is componentDist[x],
	 * entry k holds the summed mass of componentDist[k..].
	 */
	private static double[] initialDist(double[] componentDist, int k) {
		double[] dist = Arrays.copyOf(componentDist, k + 1);
		for (int i = k + 1; i < componentDist.length; ++i) {
			dist[k] += componentDist[i];
		}
		return dist;
	}

	/** Log-space counterpart of {@link #initialDist(double[], int)}. */
	private static double[] initialLogDist(double[] logComponentDist, int k) {
		double[] logDist = new double[k + 1];
		Arrays.fill(logDist, Double.NEGATIVE_INFINITY);
		System.arraycopy(logComponentDist, 0, logDist, 0, Math.min(logComponentDist.length, logDist.length));
		for (int i = k + 1; i < logComponentDist.length; ++i) {
			logDist[k] = LogSpace.logAdd(logDist[k], logComponentDist[i]);
		}
		return logDist;
	}

	/**
	 * Adds one more clump to a truncated convolution table, in place. With k=dist.length-1,
	 * on entry dist[x]=P(S=x) for x<k and dist[k]=P(S>=k) for the sum S of i-1 clumps; on
	 * return the same holds for the sum of i clumps. Indices are processed from high to
	 * low so that every read of dist[j-a] (a>=1) still sees the previous round's value.
	 */
	private static void convolveTruncated(double[] dist, double[] componentDist) {
		int k = dist.length - 1;
		int maxSize = componentDist.length - 1;
		// mass landing beyond index k; it belongs to the ">=k" bucket
		double leaving = 0.0;
		for (int j = k + maxSize; j > 0; --j) {
			double sum = 0.0;
			// only 1<=a<=maxSize with 0<=j-a<=k contribute; same order as a plain 1..maxSize scan
			int aEnd = Math.min(maxSize, j);
			for (int a = Math.max(1, j - k); a <= aEnd; ++a) {
				sum += componentDist[a] * dist[j - a];
			}
			if (j <= k) dist[j] = sum;
			else leaving += sum;
		}
		dist[k] += leaving;
	}

	/** Log-space counterpart of {@link #convolveTruncated(double[], double[])}. */
	private static void logConvolveTruncated(double[] logDist, double[] logComponentDist) {
		int k = logDist.length - 1;
		int maxSize = logComponentDist.length - 1;
		double logLeaving = Double.NEGATIVE_INFINITY;
		for (int j = k + maxSize; j > 0; --j) {
			double logSum = Double.NEGATIVE_INFINITY;
			int aEnd = Math.min(maxSize, j);
			for (int a = Math.max(1, j - k); a <= aEnd; ++a) {
				logSum = LogSpace.logAdd(logSum, logComponentDist[a] + logDist[j - a]);
			}
			if (j <= k) logDist[j] = logSum;
			else logLeaving = LogSpace.logAdd(logLeaving, logSum);
		}
		logDist[k] = LogSpace.logAdd(logDist[k], logLeaving);
	}

	/** result[j] += weight*dist[j] for every index j. */
	private static void addWeighted(double[] result, double weight, double[] dist) {
		for (int j = 0; j < result.length; ++j) result[j] += weight * dist[j];
	}

	/** Log-space counterpart of {@link #addWeighted(double[], double, double[])}. */
	private static void logAddWeighted(double[] logResult, double logWeight, double[] logDist) {
		for (int j = 0; j < logResult.length; ++j) {
			logResult[j] = LogSpace.logAdd(logResult[j], logWeight + logDist[j]);
		}
	}

	/**
	 * Returns acc plus (approximately) P(N>k), summing P(N=k+1), P(N=k+2), ... until, past
	 * the mean, a term underflows or becomes negligible relative to the running sum.
	 */
	private double addClumpCountTail(double acc, int k) {
		for (int i = k + 1;; ++i) {
			double term = getProb(i);
			acc += term;
			if (i < lambda) continue; // terms still increase below the mean
			if (term < TAIL_UNDERFLOW_LIMIT) break;
			if (acc / term > TAIL_STOP_RATIO) break;
		}
		return acc;
	}

	/** Log-space counterpart of {@link #addClumpCountTail(double, int)}. */
	private double logAddClumpCountTail(double logAcc, int k) {
		for (int i = k + 1;; ++i) {
			double logTerm = getLogProb(i);
			// With a valid lambda this only happens for lambda=0, where all further terms
			// are zero too; without this check (-inf)-(-inf)=NaN would never end the loop.
			if (logTerm == Double.NEGATIVE_INFINITY) break;
			logAcc = LogSpace.logAdd(logAcc, logTerm);
			if (i < lambda) continue; // terms still increase below the mean
			if (logAcc - logTerm > LOG_TAIL_STOP_RATIO) break;
		}
		return logAcc;
	}

	/**
	 * Returns the complementary cumulative distribution function as a table of p-values:
	 * ccdf[n]=P(X>=n) for n=0..size-1 (entries beyond pvaluePrecisionLimit() stay 0).
	 */
	public double[] getCCDF(int size) {
		double[] ccdf = new double[size];
		double pvalue = 0.0;
		// accumulate P(X>=n) from the far tail downwards
		for (int n = pvaluePrecisionLimit(); n >= 0; --n) {
			pvalue += getProb(n);
			if (n < size) ccdf[n] = pvalue;
		}
		return ccdf;
	}
	
	/**
	 * Return P(X>=k) where X is Poisson distributed, summed from pvaluePrecisionLimit()
	 * down to k (0 if k lies beyond that limit).
	 */
	public double getTailProbability(int k) {
		double p = 0.0;
		for (int n = pvaluePrecisionLimit(); n >= k; --n) {
			p += getProb(n);
		}
		return p;
	}
	
}
